import torch
from torch import nn, sigmoid
from torch.utils.data import Dataset
import numpy as np


class LogisticRegressionModel(nn.Module):
    """Logistic regression: one ``nn.Linear`` layer followed by a sigmoid.

    Args:
        input_dim: number of input features of the linear layer
        output_dim: number of outputs of the linear layer
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def init_weights_and_bias(self, weights, bias, device):
        """
        This function is used to initialize the weights and bias of the LR model
        Args:
            weights: provided weights values (array-like). When the number of
                values equals the number of weights of the linear layer, they
                are laid out in the layer's own (output_dim, input_dim) shape,
                so 1-D, row-vector and column-vector inputs all work.
            bias: provided bias values (scalar or array-like); stored as a
                1-D tensor.
            device: device on which the model is to be stored

        """
        # Accept lists/tuples as well as ndarrays (``.shape`` is used below).
        weights = np.asarray(weights)
        layer_weight = self.linear.weight

        if weights.size == layer_weight.numel():
            # Use the layer's real shape. The previous reshape turned a
            # (1, n) row vector into (n, 1), which broke ``forward``.
            weights = weights.reshape(tuple(layer_weight.shape))
        else:
            # Element count does not match the layer: keep the historical
            # reshape so existing callers observe the same result as before.
            weights = weights.reshape(-1, weights.shape[0])

        # Flatten so a scalar bias becomes shape (1,) rather than a 0-d
        # tensor; nn.Linear documents its bias as 1-D (out_features,).
        bias = np.asarray(bias).reshape(-1)

        self.linear.weight.data = torch.Tensor(weights).to(device)
        self.linear.bias.data = torch.Tensor(bias).to(device)

    def forward(self, x):
        """Apply the linear layer to ``x`` and squash the result with a sigmoid."""
        out = sigmoid(self.linear(x))
        return out

    def compute_l1_loss(self, w):
        """
        This function is used to add l1 loss to the lr model

        Args:
            w: tensor whose L1 norm is computed

        Returns:
            Sum of the absolute values of all elements of ``w``.

        """
        return torch.abs(w).sum()


class AR_Dataset(Dataset):
    """Map-style dataset that pairs each entry of ``X`` with the entry of ``y`` at the same index."""

    def __init__(self, X, y):
        """
        Store the two indexable containers served by this dataset.

        Args:
            X: indexable container of inputs; its length is the dataset length
            y: indexable container of labels, indexed in parallel with ``X``

        """
        self.X, self.y = X, y

    def __len__(self):
        """Return the number of entries in ``X``."""
        return len(self.X)

    def __getitem__(self, idx):
        """
        Return ``[X[idx], y[idx]]`` as a two-element list.

        Args:
            idx: index into the dataset; a ``torch.Tensor`` index is first
                converted with ``tolist()``

        """
        # Tensor indices (e.g. produced by pytorch's random_split) are turned
        # into plain Python values first; the original author notes this
        # works around a pandas indexing bug.
        key = idx.tolist() if isinstance(idx, torch.Tensor) else idx
        features = self.X[key]
        label = self.y[key]
        return [features, label]
